from collections import Counter
from tqdm import tqdm
import pandas as pd
import numpy as np
import plotly.express as px
from imblearn.datasets import fetch_datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
# from tabpfn import TabPFNClassifier
datasets = fetch_datasets()
datasets
OrderedDict([('ecoli',
{'data': array([[0.49, 0.29, 0.48, ..., 0.56, 0.24, 0.35],
[0.07, 0.4 , 0.48, ..., 0.54, 0.35, 0.44],
[0.56, 0.4 , 0.48, ..., 0.49, 0.37, 0.46],
...,
[0.61, 0.6 , 0.48, ..., 0.44, 0.39, 0.38],
[0.59, 0.61, 0.48, ..., 0.42, 0.42, 0.37],
[0.74, 0.74, 0.48, ..., 0.31, 0.53, 0.52]]),
'target': array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]),
'DESCR': 'ecoli'}),
('optical_digits',
{'data': array([[ 0., 1., 6., ..., 1., 0., 0.],
[ 0., 0., 10., ..., 3., 0., 0.],
[ 0., 0., 8., ..., 0., 0., 0.],
...,
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 2., ..., 12., 0., 0.],
[ 0., 0., 10., ..., 12., 1., 0.]]),
'target': array([-1, -1, -1, ..., 1, -1, 1]),
'DESCR': 'optical_digits'}),
('satimage',
{'data': array([[ 92., 115., 120., ..., 107., 113., 87.],
[ 84., 102., 106., ..., 99., 104., 79.],
[ 84., 102., 102., ..., 99., 104., 79.],
...,
[ 56., 68., 91., ..., 83., 92., 74.],
[ 56., 68., 87., ..., 83., 92., 70.],
[ 60., 71., 91., ..., 79., 108., 92.]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'satimage'}),
('pen_digits',
{'data': array([[ 47., 100., 27., ..., 90., 40., 98.],
[ 0., 89., 27., ..., 2., 100., 6.],
[ 0., 57., 31., ..., 25., 16., 0.],
...,
[ 56., 100., 27., ..., 93., 38., 93.],
[ 19., 100., 0., ..., 97., 10., 81.],
[ 38., 100., 37., ..., 26., 65., 0.]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'pen_digits'}),
('abalone',
{'data': array([[0. , 0. , 1. , ..., 0.2245, 0.101 , 0.15 ],
[0. , 0. , 1. , ..., 0.0995, 0.0485, 0.07 ],
[1. , 0. , 0. , ..., 0.2565, 0.1415, 0.21 ],
...,
[0. , 0. , 1. , ..., 0.5255, 0.2875, 0.308 ],
[1. , 0. , 0. , ..., 0.531 , 0.261 , 0.296 ],
[0. , 0. , 1. , ..., 0.9455, 0.3765, 0.495 ]]),
'target': array([-1, 1, -1, ..., -1, -1, -1]),
'DESCR': 'abalone'}),
('sick_euthyroid',
{'data': array([[ 72., 0., 1., ..., 87., 1., 0.],
[ 45., 1., 0., ..., 112., 1., 0.],
[ 64., 1., 0., ..., 123., 1., 0.],
...,
[ 58., 1., 0., ..., 95., 1., 0.],
[ 29., 1., 0., ..., 98., 1., 0.],
[ 56., 1., 0., ..., 143., 1., 0.]]),
'target': array([ 1, 1, 1, ..., -1, -1, -1]),
'DESCR': 'sick_euthyroid'}),
('spectrometer',
{'data': array([[ 4119.1675 , 4897.299 , 4163.969 , ..., 1392.4745 ,
1278.9945 , 1440.482 ],
[ 7660.999 , 7906.784 , 7821.8984 , ..., 7015.5747 ,
6962.22 , 6263.44 ],
[ 3196.4287 , 3013.9722 , 3003.149 , ..., 5954.388 ,
5337.8887 , 4638.5244 ],
...,
[15375.604 , 14542.233 , 13849.163 , ..., 824.46655,
703.36536, 649.1576 ],
[11814.46 , 12896.945 , 13033.315 , ..., 1677.9924 ,
1601.7352 , 1470.1552 ],
[14268.037 , 12925.045 , 12433.298 , ..., 702.66327,
681.01196, 670.62964]]),
'target': array([-1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1,
1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, -1, -1,
-1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1,
1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, 1, 1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1,
-1, -1, -1, -1]),
'DESCR': 'spectrometer'}),
('car_eval_34',
{'data': array([[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 1.],
[0., 0., 0., ..., 1., 0., 0.],
...,
[0., 1., 0., ..., 0., 1., 0.],
[0., 1., 0., ..., 0., 0., 1.],
[0., 1., 0., ..., 1., 0., 0.]]),
'target': array([-1, -1, -1, ..., -1, 1, 1]),
'DESCR': 'car_eval_34'}),
('isolet',
{'data': array([[-0.4394, -0.093 , 0.1718, ..., 0.641 , 0.5898, -0.4872],
[-0.4348, -0.1198, 0.2474, ..., 0.4318, 0.4546, -0.091 ],
[-0.233 , 0.2124, 0.5014, ..., 0.254 , 0.1588, -0.4762],
...,
[-0.6696, -0.373 , 0.1584, ..., 0.0728, 0.0728, -0.5818],
[-0.5764, -0.1764, 0.5106, ..., 0.3044, -0.0434, -0.5 ],
[-0.6624, -0.3334, 0.3666, ..., -0.0894, -0.1708, -0.317 ]]),
'target': array([ 1, 1, 1, ..., -1, -1, -1]),
'DESCR': 'isolet'}),
('us_crime',
{'data': array([[0.19, 0.33, 0.02, ..., 0.26, 0.2 , 0.32],
[0. , 0.16, 0.12, ..., 0.12, 0.45, 0. ],
[0. , 0.42, 0.49, ..., 0.21, 0.02, 0. ],
...,
[0.16, 0.37, 0.25, ..., 0.32, 0.18, 0.91],
[0.08, 0.51, 0.06, ..., 0.38, 0.33, 0.22],
[0.2 , 0.78, 0.14, ..., 0.3 , 0.05, 1. ]]),
'target': array([-1, 1, -1, ..., -1, -1, -1]),
'DESCR': 'us_crime'}),
('yeast_ml8',
{'data': array([[ 0.0937 , 0.139771, 0.062774, ..., -0.042402, 0.118473,
0.125632],
[-0.022711, -0.050504, -0.035691, ..., -0.014191, 0.022783,
0.123785],
[-0.090407, 0.021198, 0.208712, ..., -0.063378, -0.084181,
-0.034402],
...,
[ 0.2416 , 0.127602, -0.033072, ..., -0.038713, -0.026947,
0.00562 ],
[ 0.097274, 0.088109, 0.161101, ..., -0.019985, 0.280843,
0.143382],
[-0.001043, 0.030495, 0.007199, ..., 0.006505, -0.041307,
-0.146233]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'yeast_ml8'}),
('scene',
{'data': array([[0.646467 , 0.666435 , 0.685047 , ..., 0.247298 , 0.0140249 ,
0.0297093 ],
[0.770156 , 0.767255 , 0.761053 , ..., 0.137833 , 0.0826722 ,
0.0363203 ],
[0.793984 , 0.772096 , 0.76182 , ..., 0.0511252 , 0.112506 ,
0.0839236 ],
...,
[0.952281 , 0.944987 , 0.905556 , ..., 0.0319002 , 0.0175471 ,
0.0197344 ],
[0.88399 , 0.899004 , 0.901019 , ..., 0.256158 , 0.226332 ,
0.22307 ],
[0.974915 , 0.866425 , 0.818144 , ..., 0.0051313 , 0.0250591 ,
0.00403332]]),
'target': array([ 1, 1, -1, ..., -1, -1, -1]),
'DESCR': 'scene'}),
('libras_move',
{'data': array([[0.79691, 0.38194, 0.79691, ..., 0.3125 , 0.6383 , 0.29398],
[0.67892, 0.27315, 0.68085, ..., 0.69213, 0.17215, 0.69213],
[0.72147, 0.23611, 0.7234 , ..., 0.2662 , 0.78143, 0.27778],
...,
[0.61122, 0.75926, 0.61122, ..., 0.52083, 0.44487, 0.5162 ],
[0.65957, 0.79167, 0.65764, ..., 0.52546, 0.54159, 0.52083],
[0.64023, 0.71991, 0.64217, ..., 0.49537, 0.52031, 0.49306]]),
'target': array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1]),
'DESCR': 'libras_move'}),
('thyroid_sick',
{'data': array([[41., 1., 0., ..., 1., 0., 0.],
[23., 1., 0., ..., 0., 0., 0.],
[46., 0., 1., ..., 0., 0., 0.],
...,
[74., 1., 0., ..., 0., 0., 0.],
[72., 0., 1., ..., 0., 0., 1.],
[64., 1., 0., ..., 0., 0., 0.]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'thyroid_sick'}),
('coil_2000',
{'data': array([[33., 1., 3., ..., 0., 0., 0.],
[37., 1., 2., ..., 0., 0., 0.],
[37., 1., 2., ..., 0., 0., 0.],
...,
[36., 1., 2., ..., 0., 1., 0.],
[33., 1., 3., ..., 0., 0., 0.],
[ 8., 1., 2., ..., 0., 0., 0.]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'coil_2000'}),
('arrhythmia',
{'data': array([[ 75. , 0. , 190. , ..., 2.9, 23.3, 49.4],
[ 56. , 1. , 165. , ..., 2.1, 20.4, 38.8],
[ 54. , 0. , 172. , ..., 3.4, 12.3, 49. ],
...,
[ 36. , 0. , 166. , ..., 1. , -44.2, -33.2],
[ 32. , 1. , 155. , ..., 2.4, 25. , 46.6],
[ 78. , 1. , 160. , ..., 1.6, 21.3, 32.8]]),
'target': array([-1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1,
-1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1,
-1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]),
'DESCR': 'arrhythmia'}),
('solar_flare_m0',
{'data': array([[0., 1., 0., ..., 1., 1., 0.],
[0., 0., 1., ..., 1., 1., 0.],
[0., 1., 0., ..., 1., 0., 1.],
...,
[0., 1., 0., ..., 1., 0., 1.],
[0., 0., 0., ..., 1., 0., 1.],
[1., 0., 0., ..., 1., 0., 1.]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'solar_flare_m0'}),
('oil',
{'data': array([[1.000000e+00, 2.558000e+03, 1.506090e+03, ..., 3.324319e+04,
6.574000e+01, 7.950000e+00],
[2.000000e+00, 2.232500e+04, 7.911000e+01, ..., 5.157204e+04,
6.573000e+01, 6.260000e+00],
[3.000000e+00, 1.150000e+02, 1.449850e+03, ..., 3.169284e+04,
6.581000e+01, 7.840000e+00],
...,
[2.020000e+02, 1.400000e+01, 2.514000e+01, ..., 2.153050e+03,
6.591000e+01, 6.120000e+00],
[2.030000e+02, 1.000000e+01, 9.600000e+01, ..., 2.421430e+03,
6.597000e+01, 6.320000e+00],
[2.040000e+02, 1.100000e+01, 7.730000e+00, ..., 3.782680e+03,
6.565000e+01, 6.260000e+00]]),
'target': array([ 1, -1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1,
-1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1,
1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1,
-1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1,
-1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1]),
'DESCR': 'oil'}),
('car_eval_4',
{'data': array([[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 1.],
[0., 0., 0., ..., 1., 0., 0.],
...,
[0., 1., 0., ..., 0., 1., 0.],
[0., 1., 0., ..., 0., 0., 1.],
[0., 1., 0., ..., 1., 0., 0.]]),
'target': array([-1, -1, -1, ..., -1, -1, 1]),
'DESCR': 'car_eval_4'}),
('wine_quality',
{'data': array([[ 7. , 0.27, 0.36, ..., 3. , 0.45, 8.8 ],
[ 6.3 , 0.3 , 0.34, ..., 3.3 , 0.49, 9.5 ],
[ 8.1 , 0.28, 0.4 , ..., 3.26, 0.44, 10.1 ],
...,
[ 6.5 , 0.24, 0.19, ..., 2.99, 0.46, 9.4 ],
[ 5.5 , 0.29, 0.3 , ..., 3.34, 0.38, 12.8 ],
[ 6. , 0.21, 0.38, ..., 3.26, 0.32, 11.8 ]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'wine_quality'}),
('letter_img',
{'data': array([[ 2., 8., 3., ..., 8., 0., 8.],
[ 5., 12., 3., ..., 8., 4., 10.],
[ 4., 11., 6., ..., 7., 3., 9.],
...,
[ 6., 9., 6., ..., 12., 2., 4.],
[ 2., 3., 4., ..., 9., 5., 8.],
[ 4., 9., 6., ..., 7., 2., 8.]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'letter_img'}),
('yeast_me2',
{'data': array([[0.58, 0.61, 0.47, ..., 0. , 0.48, 0.22],
[0.43, 0.67, 0.48, ..., 0. , 0.53, 0.22],
[0.64, 0.62, 0.49, ..., 0. , 0.53, 0.22],
...,
[0.67, 0.57, 0.36, ..., 0. , 0.56, 0.22],
[0.43, 0.4 , 0.6 , ..., 0. , 0.53, 0.39],
[0.65, 0.54, 0.54, ..., 0. , 0.53, 0.22]]),
'target': array([-1, -1, -1, ..., 1, -1, -1]),
'DESCR': 'yeast_me2'}),
('webpage',
{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 1., 0., 0.],
...,
[0., 0., 0., ..., 1., 1., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.]]),
'target': array([-1, -1, -1, ..., 1, 1, 1]),
'DESCR': 'webpage'}),
('ozone_level',
{'data': array([[ 8.0000e-01, 1.8000e+00, 2.4000e+00, ..., 1.0330e+04,
-5.5000e+01, 0.0000e+00],
[ 2.8000e+00, 3.2000e+00, 3.3000e+00, ..., 1.0275e+04,
-5.5000e+01, 0.0000e+00],
[ 2.9000e+00, 2.8000e+00, 2.6000e+00, ..., 1.0235e+04,
-4.0000e+01, 0.0000e+00],
...,
[ 8.0000e-01, 8.0000e-01, 1.2000e+00, ..., 1.0275e+04,
-3.5000e+01, 0.0000e+00],
[ 1.3000e+00, 9.0000e-01, 1.5000e+00, ..., 1.0245e+04,
-3.0000e+01, 5.0000e-02],
[ 1.5000e+00, 1.3000e+00, 1.8000e+00, ..., 1.0220e+04,
-2.5000e+01, 0.0000e+00]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'ozone_level'}),
('mammography',
{'data': array([[ 0.23001961, 5.0725783 , -0.27606055, 0.83244412, -0.37786573,
0.4803223 ],
[ 0.15549112, -0.16939038, 0.67065219, -0.85955255, -0.37786573,
-0.94572324],
[-0.78441482, -0.44365372, 5.6747053 , -0.85955255, -0.37786573,
-0.94572324],
...,
[ 1.2049878 , 1.7637238 , -0.50146835, 1.5624078 , 6.4890725 ,
0.93129397],
[ 0.73664398, -0.22247361, -0.05065276, 1.5096647 , 0.53926914,
1.3152293 ],
[ 0.17700275, -0.19150839, -0.50146835, 1.5788636 , 7.750705 ,
1.5559507 ]]),
'target': array([-1, -1, -1, ..., 1, 1, 1]),
'DESCR': 'mammography'}),
('protein_homo',
{'data': array([[ 52. , 32.69, 0.3 , ..., -0.35, 0.26, 0.76],
[ 58. , 33.33, 0. , ..., 1.16, 0.39, 0.73],
[ 77. , 27.27, -0.91, ..., -0.76, 0.26, 0.24],
...,
[100. , 71.76, 41.92, ..., 3.41, 0.44, 0.78],
[ 85.65, 26.46, 1.85, ..., 2.88, 0.54, 0.77],
[ 87.5 , 29.33, 5.84, ..., -0.58, 0.16, 0.23]]),
'target': array([-1, -1, -1, ..., 1, -1, 1]),
'DESCR': 'protein_homo'}),
('abalone_19',
{'data': array([[0. , 0. , 1. , ..., 0.2245, 0.101 , 0.15 ],
[0. , 0. , 1. , ..., 0.0995, 0.0485, 0.07 ],
[1. , 0. , 0. , ..., 0.2565, 0.1415, 0.21 ],
...,
[0. , 0. , 1. , ..., 0.5255, 0.2875, 0.308 ],
[1. , 0. , 0. , ..., 0.531 , 0.261 , 0.296 ],
[0. , 0. , 1. , ..., 0.9455, 0.3765, 0.495 ]]),
'target': array([-1, -1, -1, ..., -1, -1, -1]),
'DESCR': 'abalone_19'})])
wine_quality = datasets["wine_quality"]
data, target = wine_quality["data"], wine_quality["target"]
data, target
(array([[ 7. , 0.27, 0.36, ..., 3. , 0.45, 8.8 ],
[ 6.3 , 0.3 , 0.34, ..., 3.3 , 0.49, 9.5 ],
[ 8.1 , 0.28, 0.4 , ..., 3.26, 0.44, 10.1 ],
...,
[ 6.5 , 0.24, 0.19, ..., 2.99, 0.46, 9.4 ],
[ 5.5 , 0.29, 0.3 , ..., 3.34, 0.38, 12.8 ],
[ 6. , 0.21, 0.38, ..., 3.26, 0.32, 11.8 ]]),
array([-1, -1, -1, ..., -1, -1, -1]))
target = (target == 1).astype(int)
target
array([0, 0, 0, ..., 0, 0, 0])
data.shape, target.shape
((4898, 11), (4898,))
Counter(target)
Counter({0: 4715, 1: 183})
4715 / 183
25.76502732240437
columns = [
"fixed_acidity",
"volatile_acidity",
"citric_acid",
"residual_sugar",
"chlorides",
"free_sulfur_dioxide",
"total_sulfur_dioxide",
"density",
"pH",
"sulphates",
"alcohol",
]
df = pd.DataFrame(data, columns=columns)
df.head()
| fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7.0 | 0.27 | 0.36 | 20.7 | 0.045 | 45.0 | 170.0 | 1.0010 | 3.00 | 0.45 | 8.8 |
| 1 | 6.3 | 0.30 | 0.34 | 1.6 | 0.049 | 14.0 | 132.0 | 0.9940 | 3.30 | 0.49 | 9.5 |
| 2 | 8.1 | 0.28 | 0.40 | 6.9 | 0.050 | 30.0 | 97.0 | 0.9951 | 3.26 | 0.44 | 10.1 |
| 3 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47.0 | 186.0 | 0.9956 | 3.19 | 0.40 | 9.9 |
| 4 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47.0 | 186.0 | 0.9956 | 3.19 | 0.40 | 9.9 |
X_train, X_test, y_train, y_test = train_test_split(df, target, test_size=0.3, random_state=42, shuffle=True)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((3428, 11), (1470, 11), (3428,), (1470,))
params = {
"random_state": 42,
"n_jobs": -1,
}
models = [
LogisticRegression(
max_iter=10_000,
**params
),
RandomForestClassifier(
n_estimators=50,
**params
),
XGBClassifier(
**params
),
# TabPFNClassifier(),
]
model_names = [
"LogisticRegression",
"RandomForestClassifier",
"XGBClassifier",
]
predictions = []
for model in tqdm(models):
model.fit(X_train.values, y_train)
prediction = model.predict(X_test.values)
predictions.append(prediction)
predictions
100%|██████████| 3/3 [00:00<00:00, 3.71it/s]
[array([0, 0, 0, ..., 0, 0, 0]), array([0, 0, 0, ..., 0, 0, 0]), array([0, 0, 0, ..., 0, 0, 0])]
metrics = [
accuracy_score,
precision_score,
recall_score,
f1_score,
]
summary = {}
for prediction, model_name in zip(predictions, model_names):
print(model_name)
print(Counter(prediction))
model_summary = {}
for metric in metrics:
print(metric.__name__)
score = metric(y_test, prediction)
model_summary[metric.__name__] = score
print(score)
summary[f"{model_name} {Counter(prediction)}"] = model_summary
print()
summary
LogisticRegression
Counter({0: 1470})
accuracy_score
0.9680272108843537
precision_score
0.0
recall_score
0.0
f1_score
0.0
RandomForestClassifier
Counter({0: 1460, 1: 10})
accuracy_score
0.9707482993197278
precision_score
0.7
recall_score
0.14893617021276595
f1_score
0.24561403508771928
XGBClassifier
Counter({0: 1452, 1: 18})
accuracy_score
0.9693877551020408
precision_score
0.5555555555555556
recall_score
0.2127659574468085
f1_score
0.3076923076923077
/home/karol/anaconda3/envs/xml/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result))
{'LogisticRegression Counter({0: 1470})': {'accuracy_score': 0.9680272108843537,
'precision_score': 0.0,
'recall_score': 0.0,
'f1_score': 0.0},
'RandomForestClassifier Counter({0: 1460, 1: 10})': {'accuracy_score': 0.9707482993197278,
'precision_score': 0.7,
'recall_score': 0.14893617021276595,
'f1_score': 0.24561403508771928},
'XGBClassifier Counter({0: 1452, 1: 18})': {'accuracy_score': 0.9693877551020408,
'precision_score': 0.5555555555555556,
'recall_score': 0.2127659574468085,
'f1_score': 0.3076923076923077}}
df_summary = pd.DataFrame(summary).T
df_summary
| accuracy_score | precision_score | recall_score | f1_score | |
|---|---|---|---|---|
| LogisticRegression Counter({0: 1470}) | 0.968027 | 0.000000 | 0.000000 | 0.000000 |
| RandomForestClassifier Counter({0: 1460, 1: 10}) | 0.970748 | 0.700000 | 0.148936 | 0.245614 |
| XGBClassifier Counter({0: 1452, 1: 18}) | 0.969388 | 0.555556 | 0.212766 | 0.307692 |
fig = px.bar(df_summary, barmode="group", title="Comparison of metrics on wine_quality dataset")
# fig.write_html("metrics.html")
fig.show()
import lime
from lime import lime_tabular
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
training_data=X_train.values,
feature_names=X_train.columns,
mode="classification"
)
def check_lime(idxs):
if isinstance(idxs, int):
idxs = [idxs]
for idx in idxs:
for model in models:
lime_explanation = lime_explainer.explain_instance(
data_row=X_test.iloc[idx],
predict_fn=lambda d: model.predict_proba(d)
)
print(f"Describing {model.__class__} for {idx} of label {y_test[idx]}")
lime_explanation.show_in_notebook()
check_lime(np.where(y_test == 0)[0][:3])
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 0 of label 0
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 0 of label 0
Describing <class 'xgboost.sklearn.XGBClassifier'> for 0 of label 0
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 1 of label 0
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 1 of label 0
Describing <class 'xgboost.sklearn.XGBClassifier'> for 1 of label 0
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 2 of label 0
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 2 of label 0
Describing <class 'xgboost.sklearn.XGBClassifier'> for 2 of label 0
By using lime we can see what features models consider as indicators of samples' label.
check_lime(np.where(y_test == 1)[0][:3])
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 7 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 7 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 7 of label 1
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 44 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 44 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 44 of label 1
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 72 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 72 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 72 of label 1
We can see that some features are highly associated with either label.
For example free_sulfur_dioxide for label 1 or chlorides with label 0.
Changing this values can influence models' decisions.
Get index of label 0
saved_idx = int(np.where(y_test == 1)[0][0])
check_lime(saved_idx)
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 7 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 7 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 7 of label 1
Models are sure that this data point is label 0
saved = X_test.iloc[saved_idx]
X_test.iloc[saved_idx][["free_sulfur_dioxide", "density", "alcohol"]] = 0
check_lime(saved_idx)
X_test.iloc[saved_idx] = saved
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 7 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 7 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 7 of label 1
We were able to fool LogisticRegression to think it is label 1 with 98% probability.